import json

import math
import numpy as np

from allennlp.predictors import Predictor

from attack_model.candidates import IBPCandidates
from config import Config
from tools.utils import read_json
from .attacker import BlackBoxAttacker


class PMIAttack(BlackBoxAttacker):
    """Black-box attacker that chooses synonym substitutions by PMI shift.

    For every word of the input, candidate synonyms are scored by how much
    the replacement lowers the pointwise mutual information (PMI) with the
    gold label and raises it with every other label. The highest-scoring
    positions are returned as substitutions.

    The PMI statistics are loaded from JSON files under
    ``pmi/<dataset>/`` when the attacker is constructed.
    """

    def __init__(self, cf: Config, predictor: Predictor):
        super().__init__(cf, predictor)
        # NOTE(review): `supported_postag` is not defined in this class;
        # presumably it is provided by BlackBoxAttacker -- confirm.
        self.synonym_candidate = IBPCandidates(self.supported_postag, cf.p_synonym)
        (
            self.class2prob,
            self.token2prob,
            self.class2token2prob,
            self.freq_filtered_tokens,
        ) = self.init_pmi_statistics()

    def get_labels(self):
        """Return the class labels, i.e. the keys of ``class2prob``."""
        return self.class2prob.keys()

    def get_victim_substitute_pair(self, text):
        """Select the positions to attack and the replacement for each.

        Args:
            text: Mapping read via the keys ``'sentence'`` and ``'tag'``
                (iterated in lockstep, one tag per word) and ``'label'``
                (the gold label).

        Returns:
            A list of ``(position, substitute)`` tuples ordered by
            descending score and truncated to
            ``self.attack_num(len(text['sentence']))`` entries.
        """
        labels = self.get_labels()
        gold_label = text['label']
        ranked = []

        for position, (word, tag) in enumerate(zip(text['sentence'], text['tag'])):
            lowered_word = word.lower()
            candidates, scores = [], []
            for synonym in self.synonym_candidate.candidate_set(word, tag):
                # Skip case-only variants of the word itself.
                if synonym.lower() == lowered_word:
                    continue
                # Only tokens listed in the frequency-filtered vocabulary
                # are eligible replacements.
                if synonym not in self.freq_filtered_tokens:
                    continue
                candidates.append(synonym)
                scores.append(self.get_score(word, synonym, gold_label, labels))

            if not scores:
                continue

            best = int(np.argmax(scores))
            ranked.append((position, candidates[best], scores[best]))

        # NOTE(review): `attack_num` is not defined in this class; presumably
        # inherited from BlackBoxAttacker -- confirm.
        budget = self.attack_num(len(text['sentence']))
        ranked.sort(key=lambda entry: entry[2], reverse=True)
        return [(position, substitute) for position, substitute, _ in ranked[:budget]]

    def get_score(self, src_w, tgt_w, gold_label, labels):
        """Score replacing ``src_w`` with ``tgt_w`` for a given gold label.

        For the gold label the contribution is
        ``pmi(src_w) - pmi(tgt_w)``; for every other label it is
        ``pmi(tgt_w) - pmi(src_w)``. The score is the mean contribution
        over ``labels``, so larger values mean the replacement is less
        associated with the gold label and more with the others.

        Args:
            src_w: The original word.
            tgt_w: The candidate replacement.
            gold_label: Label compared against each entry of ``labels``.
                NOTE(review): labels come from JSON keys, so this
                presumably needs to be a string -- confirm against callers.
            labels: Iterable of labels to average over.

        Returns:
            The average PMI delta as computed by ``np.average``.
        """
        delta_pmi = []
        for label in labels:
            src_pmi = self.pmi(src_w, label)
            tgt_pmi = self.pmi(tgt_w, label)
            if label == gold_label:
                delta_pmi.append(src_pmi - tgt_pmi)
            else:
                delta_pmi.append(tgt_pmi - src_pmi)
        return np.average(delta_pmi)

    def init_pmi_statistics(self):
        """Load the PMI statistics for ``self.cf.dataset``.

        Returns:
            A tuple ``(class2prob, token2prob, class2token2prob,
            freq_filtered_tokens)`` read from the JSON files of the same
            names under ``pmi/<dataset>/``.
        """
        # NOTE(review): `self.cf` is presumably set by BlackBoxAttacker.
        base = f'pmi/{self.cf.dataset}'
        class2prob = read_json(f'{base}/class2prob.json')
        # The extension separator was previously a comma ("token2prob,json").
        token2prob = read_json(f'{base}/token2prob.json')
        class2token2prob = read_json(f'{base}/class2token2prob.json')
        freq_filtered_tokens = read_json(f'{base}/freq_filtered_tokens.json')
        return class2prob, token2prob, class2token2prob, freq_filtered_tokens

    def pmi(self, token, label):
        """Return ``log2(p(token, label) / (p(token) * p(label)))``.

        Returns 0 when the token has no entry in ``token2prob`` or in
        ``class2token2prob[label]``, and also when any of the stored
        probabilities is not positive (the logarithm would be undefined).

        Raises:
            KeyError: If ``label`` is missing from ``class2prob`` or
                ``class2token2prob``.
        """
        if token not in self.token2prob:
            return 0

        pt = self.token2prob[token]
        pc = self.class2prob[label]
        token_probs = self.class2token2prob[label]
        if token not in token_probs:
            return 0

        ptc = token_probs[token]
        denominator = pt * pc
        # Treat degenerate statistics like missing ones instead of raising
        # ValueError / ZeroDivisionError from log2 or the division.
        if ptc <= 0 or denominator <= 0:
            return 0
        return math.log2(ptc / denominator)
